package com.flow.framework.core.system.checker.impl;

import com.flow.framework.common.health.ServiceHealthCheckCode;
import com.flow.framework.common.util.verify.VerifyUtil;
import com.flow.framework.core.system.checker.AbstractSystemStatusChecker;
import com.flow.framework.core.system.thread.pool.bo.TaskTimeoutBo;
import com.flow.framework.core.system.thread.pool.task.BaseCallable;
import com.flow.framework.core.system.thread.pool.task.BaseRunnable;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.MDC;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * 任务超时检查器
 *
 * @author luoguopiao
 * @version 0.0.1
 * @date 2022/1/23
 */
@Slf4j
public class ThreadPoolStatusChecker extends AbstractSystemStatusChecker implements com.flow.framework.core.system.checker.ThreadPoolStatusChecker {

    /**
     * 日志记录间隔时间，避免日志记录过多
     */
    private static final long LOG_INTERVAL = 60000;

    private final ListenedTargetHolder listenedTargetHolder = new ListenedTargetHolder();

    /**
     * 上一次记录日志的时间
     */
    private final AtomicLong lastLogTime = new AtomicLong(-1);

    /**
     * @inheritDoc
     */
    @Override
    public void onSubmit(Future<?> future, BaseRunnable baseRunnable) {
        try {
            TaskTimeoutBo taskTimeoutBo = new TaskTimeoutBo(future, System.currentTimeMillis(), baseRunnable.getTimeout(),
                    baseRunnable.getClass(), baseRunnable.getMdcContext());
            listenedTargetHolder.addTaskTimeoutDomainToListen(taskTimeoutBo);
        } catch (Exception e) {
            log.error("listen submit task error.", e);
        }
    }

    /**
     * @inheritDoc
     */
    @Override
    public <V> void onSubmit(Future<?> future, BaseCallable<V> baseCallable) {
        try {
            TaskTimeoutBo taskTimeoutBo = new TaskTimeoutBo(future, System.currentTimeMillis(), baseCallable.getTimeout(),
                    baseCallable.getClass(), baseCallable.getMdcContext());
            listenedTargetHolder.addTaskTimeoutDomainToListen(taskTimeoutBo);
        } catch (Exception e) {
            log.error("listen submit task error.", e);
        }
    }

    private Set<String> check(Map<Future<?>, TaskTimeoutBo> futureTaskTimeoutDomainMap) {
        long logStartTime = System.currentTimeMillis();
        boolean isLog = logStartTime - lastLogTime.get() > LOG_INTERVAL;
        Set<String> timeoutTaskClazzNames = new HashSet<>();

        List<TaskTimeoutBo> taskTimeoutBos = new LinkedList<>(futureTaskTimeoutDomainMap.values());
        for (TaskTimeoutBo taskTimeoutBo : taskTimeoutBos) {
            try {
                MDC.setContextMap(taskTimeoutBo.getMdcContext());
                Future<?> future = taskTimeoutBo.getFuture();
                boolean isOk = future.isDone() || future.isCancelled();
                if (!isOk) {
                    long costTime = System.currentTimeMillis() - taskTimeoutBo.getStartTime();
                    boolean isTimeout = costTime > taskTimeoutBo.getTimeout();
                    if (isTimeout) {
                        String clazzName = taskTimeoutBo.getTaskClazz().getSimpleName();
                        timeoutTaskClazzNames.add(clazzName);
                        if (isLog) {
                            lastLogTime.getAndSet(logStartTime);
                            log.error("task timeout error, task clazz : {}, cost time {}  !", clazzName, costTime);
                        }
                    }
                } else {
                    listenedTargetHolder.removeTaskTimeoutDomain(taskTimeoutBo);
                }
            } catch (Exception e) {
                log.error("timeout task listener error.", e);
            } finally {
                MDC.clear();
            }
        }
        return timeoutTaskClazzNames;
    }

    @Override
    protected Set<String> executeAsyncHealthCheck() {
        Set<String> unhealthyTags = new HashSet<>();
        try {
            Set<String> clazzNames1 = check(listenedTargetHolder.getListenerDataCache1());
            Set<String> clazzNames2 = check(listenedTargetHolder.getListenerDataCache2());
            if (!VerifyUtil.isEmpty(clazzNames1)) {
                unhealthyTags.addAll(clazzNames1);
            }
            if (!VerifyUtil.isEmpty(clazzNames2)) {
                unhealthyTags.addAll(clazzNames2);
            }
        } catch (Exception e) {
            log.error("listen task timeout error.", e);
        }
        return unhealthyTags;
    }

    @Override
    public int getServiceHealthCheckCode() {
        return ServiceHealthCheckCode.SERVICE_THREAD_POOL_CODE;
    }

    @Override
    public long getInitialDelay() {
        return 0;
    }

    @Override
    public long getPeriod() {
        return 5000;
    }

    private static class ListenedTargetHolder {

        private final Lock lock = new ReentrantLock();

        /**
         * 用于存放任务时创建的任务超时监控biz object，这里采用两个map存放的原因是如果线程池队列最大为Integer.MAX，
         * 则此时缓存map的数据大概率会大于Integer.MAX，因为部分已经处理完的任务可能还没有移除，所以需要两个map缓存数据
         */
        private final Map<Future<?>, TaskTimeoutBo> futureAndTaskTimeoutDomainMap1 = new ConcurrentHashMap<>();

        private final Map<Future<?>, TaskTimeoutBo> futureAndTaskTimeoutDomainMap2 = new ConcurrentHashMap<>();


        private void addTaskTimeoutDomainToListen(TaskTimeoutBo taskTimeoutBo) {
            String clazzName = taskTimeoutBo.getTaskClazz().getSimpleName();
            lock.lock();
            try {
                if (futureAndTaskTimeoutDomainMap1.size() != Integer.MAX_VALUE) {
                    futureAndTaskTimeoutDomainMap1.put(taskTimeoutBo.getFuture(), taskTimeoutBo);
                } else {
                    if (futureAndTaskTimeoutDomainMap2.size() != Integer.MAX_VALUE) {
                        futureAndTaskTimeoutDomainMap2.put(taskTimeoutBo.getFuture(), taskTimeoutBo);
                    } else {
                        log.error("put task error. clazz name: {}", clazzName);
                    }
                }
            } finally {
                lock.unlock();
            }
        }

        private void removeTaskTimeoutDomain(TaskTimeoutBo taskTimeoutBo) {
            Future<?> future = taskTimeoutBo.getFuture();
            if (null == futureAndTaskTimeoutDomainMap1.remove(future)) {
                futureAndTaskTimeoutDomainMap2.remove(future);
            }
        }

        private Map<Future<?>, TaskTimeoutBo> getListenerDataCache1() {
            return futureAndTaskTimeoutDomainMap1;
        }

        private Map<Future<?>, TaskTimeoutBo> getListenerDataCache2() {
            return futureAndTaskTimeoutDomainMap2;
        }
    }
}
